#!/usr/bin/python

########################################################################
# vanityhash, a hex hash fragment creation tool
# Copyright (C) 2013 Ryan Finnie <ryan@finnie.org>
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301, USA.
########################################################################

# Yay Python 3
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import division
import sys
import hashlib
import struct
import multiprocessing
import time
import getopt
import select
import codecs
import zlib
import copy
import random


class VanityHash:
    """VanityHash class.

    Reads input data from stdin and builds a hash context from it, then
    spawns worker subprocesses that each append a packed integer counter
    to a copy of the context, searching for a resulting hex digest that
    contains the desired fragment.

    Child subprocesses will have access to the instance of this class,
    forked from the master.
    """
    version = '2.0'

    # Whether stdin/stdout are binary (via the buffer interface).  This
    # is needed for compatibility with Python 2.6, 2.7 and 3.
    std_bin = False
    stdin = None
    stdout = None
    # VanityHashPoll object
    poll = None
    # hashlib context
    ctx = None
    # Dict of children's info
    children = {}
    # Total number of hashes searched, across all children
    total_searched = 0
    # Total number of hashes found, across all children
    total_found = 0
    # In append mode, whether the binary result has been printed yet
    printed_append = False
    # Epoch start time
    start_time = None
    # Next epoch time to display the progress
    next_progress_time = None
    # Number of real workers to be used
    workers_real = 0
    # List of zero-indexed workers to be used
    workers_real_l = []
    # Total workers in the worker set
    workers_total = 0
    # Proportion of real to total workers
    workers_real_fraction = 0.0
    # Pack type of hash candidates
    pack_type = b'=L'

    # Whether to display human-readable information to stderr
    opt_quiet = False
    # Search space, in bits
    opt_bits = 24
    # Total size containing the search space, in bits
    opt_bits_pack = 0
    # Endianness of the built container
    opt_byte_order = 'native'
    # Worker specification provided by the user
    opt_workers_s = 'guess'
    # Hash digest type
    opt_digest_type = 'md5'
    # How often to display progress information
    opt_progress_interval = 5.0
    # Whether to output the original data + the first result
    opt_append = False
    # Zero-indexed position within the hash to search
    opt_find_pos = 0
    # Whether to find the desired fragment anywhere in the hash
    opt_find_any_pos = False
    # Where to add a zeroed pack in append mode, if no match is found
    opt_append_empty = False

    def __init__(self):
        """Initialize the class.

        Sets up binary stdin/stdout (when the interpreter exposes the
        buffer interface) and parses command-line options, exiting with
        status 2 on an option error.
        """
        # Make stdin/stdout binary, if possible.
        if hasattr(sys.stdin, 'buffer'):
            self.std_bin = True
            self.stdin = sys.stdin.buffer
            self.stdout = sys.stdout.buffer
        else:
            self.std_bin = False
            self.stdin = sys.stdin
            self.stdout = sys.stdout

        # Parse getopts.
        try:
            self.parse_options()
        except RuntimeError as e:
            self.usage()
            self.log('Error: %s' % e)
            sys.exit(2)

    def main(self):
        """Main program loop.

        Reads stdin, spawns the workers, relays their messages until
        all of them have exited, then prints final statistics.
        """
        # Read stdin data.
        self.read_data()

        if self.opt_find_any_pos:
            self.log('Searching for %s at any position in a %d-bit space.' % (
                self.opt_find, self.opt_bits
            ))
        else:
            self.log('Searching for %s at position %d in a %d-bit space.' % (
                self.opt_find, self.opt_find_pos, self.opt_bits
            ))

        self.start_time = time.time()

        if self.workers_total == self.workers_real:
            self.log('Spawning %d worker%s...' % (
                self.workers_real, ((self.workers_real != 1) and 's' or '')
            ), newline=False)
        else:
            self.log('Spawning %d of %d worker%s (%s)...' % (
                self.workers_real, self.workers_total, ((self.workers_total != 1) and 's' or ''),
                (','.join(str(x + 1) for x in self.workers_real_l))
            ), newline=False)

        # Spawn worker children.
        self.poll = VanityHashPoller()
        for i in self.workers_real_l:
            parent_conn, child_conn = multiprocessing.Pipe(duplex=False)
            p = multiprocessing.Process(target=self.worker, args=(i, child_conn))
            p.name = 'Worker %d' % (i + 1)
            p.start()
            # Close the parent's copy of the child end; the forked
            # child retains its own.
            child_conn.close()
            self.poll.register(parent_conn)
            self.children[i] = {
                'process': p,
                'pipe': parent_conn,
            }

        self.log('done.')

        # Loop through messages from children, occasionally reporting
        # hashing progress.
        while True:
            try:
                self.query_children()
                self.report_progress()
            except KeyboardInterrupt:
                self.kill_children()

            if len(self.children) == 0:
                break

        # In append mode with --append-empty, emit a zeroed pack if no
        # match was ever written.
        if self.opt_append and self.opt_append_empty and not self.printed_append:
            self.stdout.write(b'\x00' * int(self.opt_bits_pack / 8))
            self.stdout.flush()
            self.printed_append = True

        # Final statistics.
        delta_time = time.time() - self.start_time
        self.log('Search finished in %02d:%02d, %d match%s found in %d%% of a %d-bit space.' % (
            (delta_time / 60), (delta_time % 60), self.total_found,
            ((self.total_found != 1) and 'es' or ''),
            ((self.total_searched - 1) / self.space_max * 100), self.opt_bits
        ))

    def worker(self, num_begin, conn):
        """Process hash instructions in a subprocess.

        Note that the state of the class instance is the state at the
        time of the fork from the parent process.  Communication back to
        the parent is done by the conn socket.

        Messages sent are tuples: ('FOUND', (hexdigest, candidate)),
        ('PROGRESS', hashes_done) and ('DONE', ()).
        """
        # Close other worker fds.  We don't need them.
        for childid in list(self.children):
            childdata = self.children[childid]
            childdata['pipe'].close()
            del(self.children[childid])

        to_find = self.opt_find
        to_find_len = len(to_find)
        find_pos = self.opt_find_pos
        find_anypos = self.opt_find_any_pos
        find_pos_end = find_pos + to_find_len

        # Start out with a group of 10,000 hashes.  This will be revised
        # to be approximately 2 seconds worth of hashes.
        last_report = time.time()
        report_i = 10000
        i = num_begin

        while i <= self.space_max:
            # Take into account multiple workers when determining when
            # to end the group.
            group_num_end = i + (report_i * self.workers_total)
            if group_num_end > self.space_max:
                group_num_end = self.space_max
            group_i_begin = i

            # The actual hash->test loop is as tight as possible, and
            # hence is duplicated a bit.
            if find_anypos:
                while i <= group_num_end:
                    ctxcopy = self.ctx.copy()
                    ctxcopy.update(struct.pack(self.pack_type, i))
                    hexdigest = ctxcopy.hexdigest()
                    if hexdigest.find(to_find) > -1:
                        conn.send(('FOUND', (hexdigest, i)))
                    i += self.workers_total
            else:
                while i <= group_num_end:
                    ctxcopy = self.ctx.copy()
                    ctxcopy.update(struct.pack(self.pack_type, i))
                    hexdigest = ctxcopy.hexdigest()
                    if hexdigest[find_pos:find_pos_end] == to_find:
                        conn.send(('FOUND', (hexdigest, i)))
                    i += self.workers_total

            # Figure out how many hashes were performed, and update the
            # parent.
            report_i = (i - group_i_begin) / self.workers_total
            conn.send(('PROGRESS', report_i))

            # Figure out how many hashes are needed to run for the next
            # ~2 seconds.
            # NOTE(review): assumes each group takes measurable wall
            # time; a zero time delta here would raise
            # ZeroDivisionError.
            now = time.time()
            next_report_i = int(2 * (report_i / (now - last_report)))
            last_report = now
            report_i = next_report_i

        conn.send(('DONE', ()))
        # Linger so the parent can drain the pipe; the parent
        # terminates this process once it has processed the DONE
        # message.
        time.sleep(10)

    def bytes_to_hex(self, b):
        """Return a hex representation of binary byte data."""
        return codecs.encode(b, 'hex_codec').decode('ascii')

    def parse_options(self):
        """Parse and validate command-line options.

        Raises RuntimeError on any invalid option or argument; the
        caller is expected to print usage information and exit.
        """
        try:
            opts, args = getopt.getopt(sys.argv[1:], '?b:w:d:s:at:yqp:n:e', [
                'help', 'bits=', 'workers=', 'digest=', 'progress=', 'append',
                'bits-pack=', 'any-position', 'quiet', 'position=',
                'byte-order=', 'append-empty'
            ])
        except getopt.GetoptError as err:
            raise RuntimeError(str(err))
        for o, a in opts:
            if o in ('?', '--help'):
                self.usage()
                sys.exit(2)
            elif o in ('-b', '--bits'):
                self.opt_bits = int(a)
                if (self.opt_bits < 1) or (self.opt_bits > 64):
                    raise RuntimeError('Search space must be 64 bits or less')
            elif o in ('-t', '--bits-pack'):
                self.opt_bits_pack = int(a)
            elif o in ('-w', '--workers'):
                self.opt_workers_s = a
            elif o in ('-d', '--digest'):
                # 'sha1alt' is treated as an alias for sha1.
                if a == 'sha1alt':
                    a = 'sha1'
                self.opt_digest_type = a
            elif o in ('-s', '--progress'):
                self.opt_progress_interval = float(a)
            elif o in ('-a', '--append'):
                self.opt_append = True
            elif o in ('-p', '--position'):
                self.opt_find_pos = int(a)
            elif o in ('-y', '--any-position'):
                self.opt_find_any_pos = True
            elif o in ('-q', '--quiet'):
                self.opt_quiet = True
            elif o in ('-n', '--byte-order'):
                if a in ('native', 'little', 'big'):
                    self.opt_byte_order = a
                else:
                    raise RuntimeError('Invalid byte order, must be one of: %s' % str(('native', 'little', 'big')))
            elif o in ('-e', '--append-empty'):
                self.opt_append_empty = True
            else:
                assert False, 'unhandled option %s' % o

        if len(args) < 1:
            self.usage()
            sys.exit(2)
        self.opt_find = args[0].lower()

        # Generate the container size if not specified.
        if self.opt_bits_pack == 0:
            self.opt_bits_pack = 1
            while self.opt_bits_pack < self.opt_bits:
                self.opt_bits_pack *= 2
            if self.opt_bits_pack < 8:
                self.opt_bits_pack = 8
        # Validate the container size.
        if (self.opt_bits_pack < self.opt_bits) or (self.opt_bits_pack > 64):
            raise RuntimeError('Invalid bits-pack')
        # Make sure the container size is a power of 2.
        bits_pack_bytes = int(self.opt_bits_pack / 8)
        if not (bits_pack_bytes & (bits_pack_bytes - 1)) == 0:
            raise RuntimeError('Invalid bits-pack')

        # Validate the desired hex fragment
        for i in self.opt_find:
            if not i in '0 1 2 3 4 5 6 7 8 9 a b c d e f'.split():
                raise RuntimeError('Invalid search hex string')

        # Pre-compute the largest integer to be tested.
        self.space_max = 0
        for i in range(0, self.opt_bits):
            self.space_max += 2 ** i

        # Build a pack type based on the bits_pack size.
        if self.opt_byte_order == 'little':
            self.pack_type = b'<'
        elif self.opt_byte_order == 'big':
            self.pack_type = b'>'
        else:
            self.pack_type = b'='
        if self.opt_bits_pack == 64:
            self.pack_type += b'Q'
        elif self.opt_bits_pack == 32:
            self.pack_type += b'L'
        elif self.opt_bits_pack == 16:
            self.pack_type += b'H'
        else:
            self.pack_type += b'B'

        # Build the worker options.
        if self.opt_workers_s == 'guess':
            try:
                self.opt_workers_s = str(multiprocessing.cpu_count())
            except NotImplementedError:
                self.opt_workers_s = str(1)
        if self.opt_workers_s.isdigit():
            # If a single number is given, the real and total workers
            # are the same.
            self.workers_total = int(self.opt_workers_s)
            # (A range object is fine here; workers_real_l is only
            # iterated, tested with len() and membership.)
            self.workers_real_l = range(self.workers_total)
        else:
            # If a specification is given, validate and build according
            # to the specification.
            try:
                (workert, workerx) = self.opt_workers_s.split(':')
            except ValueError:
                raise RuntimeError('Invalid worker specification')
            self.workers_total = int(workert)
            for i in workerx.split(','):
                if not i.isdigit():
                    raise RuntimeError('Invalid worker specification')
                i = int(i)
                if (i > self.workers_total) or (i < 1):
                    raise RuntimeError('Invalid worker specification')
                if not (i - 1) in self.workers_real_l:
                    self.workers_real_l.append(i - 1)
                self.workers_real_l.sort()
        self.workers_real = len(self.workers_real_l)
        if (self.workers_total < 1) or (self.workers_real < 1):
            raise RuntimeError('Invalid number of workers')
        if self.workers_real > 128:
            raise RuntimeError('Cannot be more than 128 workers')
        self.workers_real_fraction = self.workers_real / self.workers_total

        # Test the hash type is valid.
        try:
            testctx = ExtendedHashlib().new(self.opt_digest_type)
        except ValueError:
            raise RuntimeError('Invalid digest type')

        # Test the position specified is correct according to the given
        # hash type.
        hexdigestsize = testctx.digest_size * 2
        maxpos = hexdigestsize - len(self.opt_find)
        # A negative position counts back from the end of the digest.
        if self.opt_find_pos < 0:
            self.opt_find_pos += hexdigestsize
        if self.opt_find_pos > maxpos:
            raise RuntimeError('Pattern position %d goes beyond end of %s digest' % (
                self.opt_find_pos, self.opt_digest_type.upper())
            )

    def log(self, text='', newline=True):
        """Write text to stderr, unless --quiet was given."""
        if self.opt_quiet:
            return
        if newline:
            print(text, file=sys.stderr)
        else:
            print(text, file=sys.stderr, end='')

    def query_children(self):
        """Check the poll object for available children messages.

        Reads at most one message per ready child, and reaps children
        that have exited (logging any non-zero exit status).
        """
        available_fds = self.poll.poll(1000)
        for childid in list(self.children):
            # The child may be killed during the loop
            if not childid in self.children:
                continue
            childdata = self.children[childid]
            p = childdata['process']
            if childdata['pipe'].fileno() in available_fds:
                msg = childdata['pipe'].recv()
                self.process_message(childid, msg)
            if not p.is_alive():
                if not p.exitcode == 0:
                    self.log('Worker %d (pid %d) died with exit status %d' % (childid + 1, p.pid, p.exitcode))
                self.poll.unregister(childdata['pipe'])
                childdata['pipe'].close()
                del(self.children[childid])

    def process_message(self, childid, msg):
        """Parse a received child message.

        Handles the three message types produced by worker():
        PROGRESS (hash count), DONE (worker finished its range) and
        FOUND (matching digest plus the candidate integer).
        """
        if msg[0] == 'PROGRESS':
            self.total_searched += msg[1]
        elif msg[0] == 'DONE':
            self.kill_children([childid])
        elif msg[0] == 'FOUND':
            (msgdigest, msgdata) = msg[1]
            msgdata = struct.pack(self.pack_type, msgdata)
            self.log('Match found: 0x%s -> %s %s' % (
                self.bytes_to_hex(msgdata), self.opt_digest_type.upper(), msgdigest)
            )
            self.total_found += 1
            if self.opt_append:
                # Append mode only wants the first match; once written,
                # stop all workers.
                if not self.printed_append:
                    self.stdout.write(msgdata)
                    self.stdout.flush()
                    self.printed_append = True
                    self.kill_children()
            else:
                if self.std_bin:
                    self.stdout.write(bytearray('%s %s\n' % (self.bytes_to_hex(msgdata), msgdigest), 'ascii'))
                else:
                    self.stdout.write('%s %s\n' % (self.bytes_to_hex(msgdata), msgdigest))
                self.stdout.flush()

    def read_data(self):
        """Read data from stdin and build the initial hash context."""
        self.ctx = ExtendedHashlib().new(self.opt_digest_type)
        self.log('Reading input data and adding to digest...', newline=False)
        datalen = 0
        while True:
            buf = self.stdin.read(1024)
            if not buf:
                break
            # In append mode, the input is echoed straight through to
            # stdout.
            if self.opt_append:
                self.stdout.write(buf)
            datalen += len(buf)
            self.ctx.update(buf)
        if self.opt_append:
            self.stdout.flush()
        self.log('done.')

        origdigest = self.ctx.copy().hexdigest()
        self.log('Original data: %d bytes, %s %s' % (datalen, self.opt_digest_type.upper(), origdigest))

    def report_progress(self):
        """Occasionally output progress statistics."""
        if not self.next_progress_time:
            self.next_progress_time = self.start_time + self.opt_progress_interval

        now = time.time()
        if not now > self.next_progress_time:
            return
        elapsed = now - self.start_time
        percent = self.total_searched / (self.space_max * self.workers_real_fraction) * 100
        if self.total_searched > 0:
            # Estimated time remaining, from the average hash rate so
            # far.
            remaining = (
                (self.space_max * self.workers_real_fraction - self.total_searched) / (self.total_searched / elapsed)
            )
            self.log('%3d%% searched, ~%02d:%02d remaining...' % (
                percent, (remaining / 60), (remaining % 60))
            )
        else:
            self.log('%3d%% searched...' % percent)
        self.next_progress_time = now + self.opt_progress_interval

    def kill_children(self, childids=None):
        """Kill the given child subprocesses (all children if childids
        is None)."""
        if not childids:
            childids = list(self.children)
        for childid in childids:
            childdata = self.children[childid]
            childdata['process'].terminate()
            self.poll.unregister(childdata['pipe'])
            childdata['pipe'].close()
            del(self.children[childid])

    def usage(self):
        """Output usage information."""
        self.log('vanityhash version %s' % self.version)
        self.log('Copyright (C) 2013 Ryan Finnie <ryan@finnie.org>')
        self.log()
        self.log('Usage:')
        self.log()
        self.log('    vanityhash [ options ] hexfragment < inputfile')
        self.log('    vanityhash --append [ options ] hexfragment < inputfile > outputfile')
        self.log()


class VanityHashPoller:
    """VanityHashPoller class.

    This class emulates the basic functionality of select.poll(), but
    will fall back to manually tracking polling on operating systems
    where select.poll() is not available.

    Note that this is not a drop-in replacement for select.poll().  For
    example, register/unregister take filehandles, not file descriptors,
    but they return file descriptors.
    """

    def __init__(self):
        """Initialize and test for select.poll()."""
        # Registered filehandles.  This must be set per-instance; a
        # class-level list would be shared by every instance of the
        # class.
        self.objs = []
        if hasattr(select, 'poll'):
            self.selectpoll = select.poll()
        else:
            self.selectpoll = None

    def register(self, obj):
        """Register a filehandle for read polling."""
        assert obj not in self.objs
        self.objs.append(obj)
        if self.selectpoll:
            self.selectpoll.register(obj.fileno(), select.POLLIN)

    def unregister(self, obj):
        """Unregister a previously registered filehandle."""
        assert obj in self.objs
        self.objs.remove(obj)
        if self.selectpoll:
            self.selectpoll.unregister(obj.fileno())

    def poll(self, timeout=None):
        """Poll the registered filehandles.

        Returns a list of file descriptors ready for reading.  If
        timeout is None (default), the method will block until there
        are file descriptors to return.
        """
        if self.selectpoll:
            pollres = self.selectpoll.poll(timeout)
            if not pollres:
                return []
            available_fds = []
            for (fd, event) in pollres:
                # Test the POLLIN bit rather than the whole mask: a
                # pipe whose writer has exited reports POLLIN|POLLHUP,
                # and its final messages must still be readable.
                if not event & select.POLLIN:
                    continue
                available_fds.append(fd)
            return available_fds
        else:
            # Manual fallback: repeatedly ask each (open) filehandle's
            # own poll() method whether data is waiting.
            while True:
                available_fds = []
                for obj in self.objs:
                    if obj.closed:
                        continue
                    if obj.poll():
                        available_fds.append(obj.fileno())
                if len(available_fds) > 0:
                    return available_fds
                else:
                    if timeout is None:
                        time.sleep(1)
                    else:
                        time.sleep(float(timeout) / 1000)
                        return []


class HashlibCRC32:
    """hashlib-compatible CRC32.

    Implements the subset of the hashlib context interface used by
    this program: update(), copy(), digest() and hexdigest().
    """
    name = 'crc32'
    digestsize = 4
    digest_size = 4
    block_size = 1

    def __init__(self):
        # Running CRC; zlib.crc32's starting value is 0.  Kept as
        # per-instance state rather than a class attribute.
        self._crc = 0

    def copy(self):
        """Return an independent copy of this context."""
        return copy.copy(self)

    def update(self, data):
        """Fold more bytes into the running checksum."""
        self._crc = zlib.crc32(data, self._crc)

    def digest(self):
        """Return the checksum as 4 big-endian bytes."""
        # Mask to 32 bits: zlib.crc32 could return signed values on
        # Python 2.
        return struct.pack(b'>I', (self._crc & 0xffffffff))

    def hexdigest(self):
        """Return the checksum as a lowercase hex string.

        Decoding to str matches hashlib's hexdigest() contract.  On
        Python 3, codecs.encode() alone returns bytes, which silently
        breaks the str comparisons in the hash-search loop.
        """
        return codecs.encode(self.digest(), 'hex_codec').decode('ascii')


class HashlibAdler32:
    """hashlib-compatible Adler-32.

    Implements the subset of the hashlib context interface used by
    this program: update(), copy(), digest() and hexdigest().
    """
    name = 'adler32'
    digestsize = 4
    digest_size = 4
    block_size = 1

    def __init__(self):
        # Running checksum; zlib.adler32's starting value is 1.  Kept
        # as per-instance state rather than a class attribute.
        self._checksum = 1

    def copy(self):
        """Return an independent copy of this context."""
        return copy.copy(self)

    def update(self, data):
        """Fold more bytes into the running checksum."""
        self._checksum = zlib.adler32(data, self._checksum)

    def digest(self):
        """Return the checksum as 4 big-endian bytes."""
        # Mask to 32 bits: zlib.adler32 could return signed values on
        # Python 2.
        return struct.pack(b'>I', (self._checksum & 0xffffffff))

    def hexdigest(self):
        """Return the checksum as a lowercase hex string.

        Decoding to str matches hashlib's hexdigest() contract.  On
        Python 3, codecs.encode() alone returns bytes, which silently
        breaks the str comparisons in the hash-search loop.
        """
        return codecs.encode(self.digest(), 'hex_codec').decode('ascii')


class HashlibRandom:
    """hashlib-compatible dummy random module.

    update() ignores the supplied data entirely and regenerates a
    fresh random "digest" of digest_size bytes each time it is called.
    """
    name = 'random'
    digestsize = 32
    digest_size = 32
    block_size = 64

    def __init__(self, digest_size=32, block_size=64):
        """Initialize with the requested digest and block sizes."""
        self.digestsize = digest_size
        self.digest_size = digest_size
        self.block_size = block_size
        # Seed the initial random digest.
        self.update(b'')

    def copy(self):
        """Return an independent copy of this context."""
        return copy.copy(self)

    def update(self, data):
        """Replace the stored digest with fresh random bytes.

        The data argument is accepted for interface compatibility but
        ignored.
        """
        self._checksum = bytearray(
            random.randint(0, 255) for _ in range(self.digest_size)
        )

    def digest(self):
        """Return the current random digest as bytes."""
        return bytes(self._checksum)

    def hexdigest(self):
        """Return the current random digest as a lowercase hex string.

        Decoding to str matches hashlib's hexdigest() contract.  On
        Python 3, codecs.encode() alone returns bytes, which silently
        breaks the str comparisons in the hash-search loop.
        """
        return codecs.encode(self.digest(), 'hex_codec').decode('ascii')


class ExtendedHashlib:
    """hashlib-compatible extension system.

    Wraps the stdlib hashlib module and adds this file's extra
    algorithms (random, crc32, adler32) to its catalog.
    """
    extended_algorithms = {
        'random': HashlibRandom,
        'crc32': HashlibCRC32,
        'adler32': HashlibAdler32
    }

    def __init__(self):
        extras = tuple(self.extended_algorithms.keys())
        if hasattr(hashlib, 'algorithms'):
            # Python 2.6/2.7-style interface.
            self._hashlib_algorithms = hashlib.algorithms
            self.algorithms = hashlib.algorithms + extras
            for algo in hashlib.algorithms:
                setattr(self, algo, getattr(hashlib, algo))
        elif hasattr(hashlib, 'algorithms_available'):
            # Python 3-style interface.
            self._hashlib_algorithms = hashlib.algorithms_available
            self.algorithms_available = set(
                tuple(hashlib.algorithms_available) + extras)
            self.algorithms_guaranteed = hashlib.algorithms_guaranteed
            for algo in hashlib.algorithms_guaranteed:
                setattr(self, algo, getattr(hashlib, algo))

    def new(self, algo, **kwargs):
        """Return a new context for the named algorithm.

        Extension algorithms take precedence; anything else is passed
        through to hashlib.new().
        """
        if algo in self.extended_algorithms:
            return self.extended_algorithms[algo](**kwargs)
        return hashlib.new(algo, **kwargs)


if __name__ == '__main__':
    # Script entry point; option parsing happens in the constructor.
    VanityHash().main()
